#%%
import argparse
import os
import time

import dgl
import numpy as np
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from models import GCN
from models.GCN import fairness_loss, GCN_Body
from utils import feature_norm
from utils import load_data, accuracy, load_pokec, load_bail


def optimize_sup(model, adj, x, label, sens, cv):
    """Run one training step of ``model`` and return its loss.

    The mode is chosen by the module-level flags ``fair`` and ``slayer``,
    which the training loop sets per fold:

    * ``fair``: the forward pass returns predictions, hidden features and
      group statistics; the update is delegated to ``model.optimize``.
    * ``slayer``: single-layer baseline; the update is delegated to
      ``model.optimize_slayer``.
    * otherwise: a plain step of binary cross-entropy with logits on the
      training nodes, applied with the module-level ``optimizer``.

    In the two delegated modes, ``args.adversarial == 'True'`` is forwarded
    as the ``adversarial`` keyword.

    Args:
        model: network called as ``model(x, adj)``.
        adj: adjacency passed through to ``model``.
        x: node features passed through to ``model``.
        label: node labels; the loss only looks at ``label[idx_train]``.
        sens: sensitive-attribute tensor, forwarded to the delegated optimizers.
        cv: fold index; currently unused.

    Returns:
        The loss of this step (``model.G_loss`` in the delegated modes).

    Also reads the module-level globals ``args``, ``idx_train``, ``idx_s0``,
    ``idx_s1`` and ``optimizer``.
    """
    model.train()
    if fair:
        y, h, means1, means2, stds1, stds2, all_std = model(x, adj)
        model.optimize(h, sens, idx_train, y, label, means1, means2, stds1, stds2,
                       all_std, idx_s0, idx_s1, args.hp1, args.hp2, args.kappa,
                       args.eta, args.hp3, adversarial=args.adversarial == 'True')
        print('cov_loss: ', model.cov)
        print('adv_loss: ', model.adv_loss)
        print('relaxed_loss: ', model.relaxed)
        print('fair_norm loss: ', model.fair_norm)
        loss = model.G_loss
    elif slayer:
        print('single layer for baselines')
        y, h = model(x, adj)
        model.optimize_slayer(h, sens, idx_train, y, label, idx_s0, idx_s1,
                              args.hp1, args.hp2, args.kappa, args.eta, args.hp3,
                              adversarial=args.adversarial == 'True')
        print('cov_loss: ', model.cov)
        print('adv_loss: ', model.adv_loss)
        print('relaxed_loss: ', model.relaxed)
        loss = model.G_loss
    else:
        optimizer.zero_grad()
        y, h = model(x, adj)
        # Use the `label` argument (not the module-level `labels`) so this
        # branch trains on what the caller passed, like the other branches.
        loss = F.binary_cross_entropy_with_logits(
            y[idx_train], label[idx_train].unsqueeze(1).float())
        loss.backward()
        optimizer.step()
    return loss


def fair_metric(output, idx):
    """Return the statistical-parity and equal-opportunity gaps on ``idx``.

    A node counts as predicted positive when its output is ``> 0``. Groups
    come from the module-level ``sens`` (value 0 vs value 1), and ground
    truth from the module-level ``labels``:

    * parity   = |P(pred=1 | s=0) - P(pred=1 | s=1)|
    * equality = |P(pred=1 | y=1, s=0) - P(pred=1 | y=1, s=1)|

    Args:
        output: one score per node (e.g. shape ``(N, 1)``), thresholded at 0.
        idx: index tensor selecting the nodes to evaluate.

    Returns:
        ``(parity, equality)``. A term is NaN when one of its groups is empty.
    """
    def _positive_rate(pred, mask):
        # Fraction of predicted positives inside `mask`; NaN for an empty group
        # (instead of a ZeroDivisionError / division warning).
        count = mask.sum()
        return pred[mask].sum() / count if count else np.nan

    idx_np = idx.cpu().numpy()
    val_y = labels[idx].cpu().numpy()
    group = sens.cpu().numpy()[idx_np]
    idx_s0 = group == 0
    idx_s1 = group == 1
    idx_s0_y1 = idx_s0 & (val_y == 1)
    idx_s1_y1 = idx_s1 & (val_y == 1)

    # reshape(-1) rather than squeeze(): squeeze() turns a single selected node
    # into a 0-d array, which cannot be boolean-masked below.
    pred_y = (output[idx].reshape(-1) > 0).type_as(labels).cpu().numpy()

    parity = abs(_positive_rate(pred_y, idx_s0) - _positive_rate(pred_y, idx_s1))
    equality = abs(_positive_rate(pred_y, idx_s0_y1) - _positive_rate(pred_y, idx_s1_y1))
    return parity, equality


# Training settings
# NOTE: --fairness, --dlayer, --slayer and --adversarial are *string* options;
# the training code enables them only when they equal the literal 'True'.
# NOTE(review): --fastmode, --hidden, --alpha, --beta, --model, the GAT options
# (--num-heads ... --negative-slope), --acc, --roc and --hp4 are not read
# anywhere in this file, and the 'nba'/'credit' dataset choices have no loader
# here.
parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='Disables CUDA training.')
parser.add_argument('--fastmode', action='store_true', default=False,
                    help='Validate during training pass.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--epochs', type=int, default=1000,
                    help='Number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.001,
                    help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=1e-5,
                    help='Weight decay (L2 loss on parameters).')
parser.add_argument('--hidden', type=int, default=128,
                    help='Number of hidden units of the sensitive attribute estimator')
parser.add_argument('--dropout', type=float, default=0,
                    help='Dropout rate (1 - keep probability).')
parser.add_argument('--alpha', type=float, default=4,
                    help='The hyperparameter of alpha')
parser.add_argument('--kappa', type=float, default=1)
parser.add_argument('--eta', type=float, default=1)
parser.add_argument('--activation',type=str,default='relu')
parser.add_argument('--beta', type=float, default=0.01,
                    help='The hyperparameter of beta')
parser.add_argument('--model', type=str, default="GCN",
                    help='the type of model GCN/GAT')
parser.add_argument('--dataset', type=str, default='pokec',
                    choices=['pokec','pokec2','nba', 'credit', 'bail'])
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--fairness', type=str, default='False')
parser.add_argument('--dlayer', type=str, default='False')
parser.add_argument('--slayer', type=str, default='False')
parser.add_argument('--adversarial', type=str, default='False')
parser.add_argument('--num-hidden', type=int, default=64,
                    help='Number of hidden units of classifier.')
parser.add_argument("--num-heads", type=int, default=1,
                    help="number of hidden attention heads")
parser.add_argument("--num-out-heads", type=int, default=1,
                    help="number of output attention heads")
parser.add_argument("--num-layers", type=int, default=2,
                    help="number of hidden layers")
parser.add_argument("--residual", action="store_true", default=False,
                    help="use residual connection")
parser.add_argument("--attn-drop", type=float, default=.0,
                    help="attention dropout")
parser.add_argument('--negative-slope', type=float, default=0.2,
                    help="the negative slope of leaky relu")
parser.add_argument('--acc', type=float, default=0.688,
                    help='the selected FairGNN accuracy on val would be at least this high')
parser.add_argument('--roc', type=float, default=0.745,
                    help='the selected FairGNN ROC score on val would be at least this high')
parser.add_argument('--sens_number', type=int, default=200,
                    help="the number of sensitive attributes")
parser.add_argument('--norm', type=str, default='bn')
parser.add_argument('--hp1',type=float,default=1)
parser.add_argument('--hp2',type=float,default=1)
parser.add_argument('--hp3',type=float,default=1)
parser.add_argument('--hp4',type=float,default=1)
parser.add_argument('--label_number', type=int, default=500,
                    help="the number of labels")

# parse_known_args ignores unrecognised arguments -- presumably so the script
# also runs inside IPython/Jupyter (note the `#%%` cell marker); TODO confirm.
args = parser.parse_known_args()[0]
args.cuda = not args.no_cuda and torch.cuda.is_available()

gpu_id = args.gpu

# Reproducibility: seed every RNG and force deterministic kernels.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# With deterministic algorithms enabled, cuBLAS (CUDA >= 10.2) raises a
# RuntimeError unless CUBLAS_WORKSPACE_CONFIG is set before its first use.
# Keep any value supplied by the environment; otherwise use the documented one.
os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
torch.use_deterministic_algorithms(True)
torch.cuda.manual_seed_all(args.seed)

activation = ({'relu': F.relu, 'prelu': nn.PReLU(), 'tanh': nn.Tanh(), 'sigmoid':nn.Sigmoid()})[args.activation]

# Per-fold results: training accuracy at every epoch, plus the test metrics
# recorded at the epoch with the best validation accuracy.
overall_acc = np.zeros((5, args.epochs))
test_acc = np.zeros(5)
test_roc = np.zeros(5)
test_sp = np.zeros(5)
test_eo = np.zeros(5)
from sklearn.metrics import roc_auc_score

# Five folds: each loads a fresh split (seed 20 + cv), trains a new model, and
# records the test metrics of the epoch with the best validation accuracy.
for cv in range(5):
    # ---- Data ----
    if args.dataset in ('pokec', 'pokec2'):
        dataset = 'region_job' if args.dataset == 'pokec' else 'region_job_2'
        sens_attr = "region"
        predict_attr = "I_am_working_in_field"
        label_number = args.label_number
        sens_number = args.sens_number
        seed = 20 + cv
        path = "../dataset/pokec/"
        test_idx = False
        adj, features, labels, idx_train, idx_val, idx_test, sens = load_pokec(
            dataset,
            sens_attr,
            predict_attr,
            path=path,
            label_number=label_number,
            sens_number=sens_number,
            seed=seed,
            test_idx=test_idx,
        )
        features = feature_norm(features)
    elif args.dataset == 'bail':
        dataset = 'bail'
        sens_attr = "WHITE"
        predict_attr = "RECID"
        path = "../dataset/bail/"
        seed = 20 + cv
        adj, features, labels, idx_train, idx_val, idx_test, sens = load_bail(
            dataset,
            sens_attr,
            predict_attr,
            path=path,
            seed=seed,
        )
        features = feature_norm(features)
    else:
        # '--dataset' also accepts 'nba' and 'credit', but no loader is wired
        # up for them here; fail clearly instead of a NameError on `adj` below.
        raise ValueError(
            "dataset {!r} is not supported by this script; use one of "
            "'pokec', 'pokec2', 'bail'".format(args.dataset))

    # Binarise the labels and the sensitive attribute.
    labels[labels > 1] = 1
    sens[sens > 0] = 1

    # Boolean masks over idx_train by sensitive group. optimize_sup() reads
    # these -- and fair / slayer / optimizer below -- as module-level globals.
    idx_s0 = sens[idx_train] == 0
    idx_s1 = sens[idx_train] == 1
    fair = args.fairness == 'True'
    node_ids_s0 = np.where(sens == 0)[0]
    node_ids_s1 = np.where(sens == 1)[0]
    print('fairness is: ', fair)

    dlayer = args.dlayer == 'True'
    slayer = args.slayer == 'True'

    # ---- Model and optimizer ----
    model = GCN_Body(features.shape[1], args.num_hidden, 1, args.dropout, activation,
                     node_ids_s0, node_ids_s1, fairness=fair, dlayer=dlayer,
                     k=args.num_layers, norm_type=args.norm, gpu=args.gpu)
    if args.cuda:
        model.cuda(device=gpu_id)
        adj = adj.cuda(device=gpu_id)
        features = features.cuda(device=gpu_id)
        labels = labels.cuda(device=gpu_id)
        idx_train = idx_train.cuda(device=gpu_id)
        idx_val = idx_val.cuda(device=gpu_id)
        idx_test = idx_test.cuda(device=gpu_id)
        sens = sens.cuda(device=gpu_id)
    # Built after the device move, as the torch.optim docs recommend.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # ---- Training ----
    t_total = time.time()
    best_result = {}
    best_acc = 0

    for epoch in range(args.epochs):
        train_loss = optimize_sup(model, adj, features, labels, sens, 0)

        # Evaluate on all nodes; no autograd graph is needed for metrics.
        model.eval()
        with torch.no_grad():
            # Both model variants return a tuple whose first entry is the output.
            output = model(features, adj)[0]

        acc_val = accuracy(output[idx_val], labels[idx_val])
        roc_val = roc_auc_score(labels[idx_val].cpu().numpy(), output[idx_val].detach().cpu().numpy())
        parity_val, equality_val = fair_metric(output, idx_val)

        acc_train = accuracy(output[idx_train], labels[idx_train])
        # .item(): numpy cannot convert a CUDA tensor when assigning into an array.
        overall_acc[cv, epoch] = acc_train.item()

        acc_test = accuracy(output[idx_test], labels[idx_test])
        roc_test = roc_auc_score(labels[idx_test].cpu().numpy(), output[idx_test].detach().cpu().numpy())
        parity, equality = fair_metric(output, idx_test)

        # Model selection by validation accuracy.
        if acc_val.item() > best_acc:
            best_acc = acc_val.item()
            best_result['acc'] = acc_test.item()
            best_result['roc'] = roc_test
            best_result['parity'] = parity
            best_result['equality'] = equality

        print("=================================")

        print('Epoch: {:04d}'.format(epoch+1),
              'loss: {:.4f}'.format(train_loss),
              'acc_val: {:.4f}'.format(acc_val.item()),
              "roc_val: {:.4f}".format(roc_val),
              "parity_val: {:.4f}".format(parity_val),
              "equality: {:.4f}".format(equality_val))
        print("Test:",
              "accuracy: {:.4f}".format(acc_test.item()),
              "roc: {:.4f}".format(roc_test),
              "parity: {:.4f}".format(parity),
              "equality: {:.4f}".format(equality))

    print("Optimization Finished!")
    print("Total time elapsed: {:.4f}s".format(time.time() - t_total))

    print('============performace on test set=============')
    if len(best_result) > 0:
        test_acc[cv] = best_result['acc']
        test_roc[cv] = best_result['roc']
        test_sp[cv] = best_result['parity']
        test_eo[cv] = best_result['equality']
        print("Test:",
              "accuracy: {:.4f}".format(best_result['acc']),
              "roc: {:.4f}".format(best_result['roc']),
              "parity: {:.4f}".format(best_result['parity']),
              "equality: {:.4f}".format(best_result['equality']))
    else:
        print("Please set smaller acc/roc thresholds")

    # Reseed so every fold starts from the same RNG state.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
# Save the per-epoch training accuracy averaged over the five folds and its
# spread across folds. Both files share one run-configuration suffix.
_run_suffix = (
    f"_hp2_{args.hp2}_eta_{args.eta}_kappa_{args.kappa}"
    f"_fair_{args.fairness}_dlayer_{args.dlayer}_{args.norm}_{args.activation}"
    f"_supervised_{args.dataset}.npy"
)
np.save(f"disc_hp1_{args.hp1}{_run_suffix}", overall_acc.mean(axis=0))
np.save(f"disc_std_{args.hp1}{_run_suffix}", overall_acc.std(axis=0))

# One "name: mean" / "std-name: std" pair per test metric, on a single line.
_overall_fields = []
for _metric, _std_metric, _scores in (
    ('accuracy', 'acc std', test_acc),
    ('roc', 'roc std', test_roc),
    ('parity', 'sp std', test_sp),
    ('equality', 'eo std', test_eo),
):
    _overall_fields.append('{}: {:.4f}'.format(_metric, np.mean(_scores)))
    _overall_fields.append('{}: {:.4f}'.format(_std_metric, np.std(_scores)))
print("Overall:", *_overall_fields)